import torch
from torch_geometric.data import Data
from instance_genetrator import * 

class Parallel_machine_tardiness():
    def __init__(self, env_params):
        """Store the environment configuration on the instance.

        Args:
            env_params: Mapping that must provide the keys ``'batch_size'``,
                ``'pomo_size'``, ``'job_num'``, ``'machine_num'`` and
                ``'process_time_params'``.  The mapping itself is kept as
                ``self.env_params``.

        Raises:
            KeyError: If ``env_params`` is a dict and one of the keys above
                is missing.
        """
        # Prefer the GPU when one is visible, otherwise fall back to the CPU.
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        self.env_params = env_params

        # ``batch_size`` and ``batch`` both start out as the configured value.
        configured_batch = env_params['batch_size']
        self.batch_size = configured_batch
        self.batch = configured_batch

        self.pomo = env_params['pomo_size']
        self.job_num = env_params['job_num']
        self.machine_num = env_params['machine_num']
        self.process_time_params = env_params['process_time_params']
        
    def _reset(self, batch, pomo):
        """Load (or generate) a batch of instances and reset the simulator.

        Args:
            batch: When ``pomo == -1`` this is a ready-made problem dictionary
                with the keys ``'p'``, ``'setup_times'``, ``'initial_setup'``,
                ``'w'``, ``'release_times'``, ``'due_dates'`` and
                ``'machine_eligibility'``.  Otherwise it is the batch count
                handed to the instance generator.
            pomo: ``-1`` to use ``batch`` as given.  Any other value is
                forwarded to the generator and the effective batch size
                becomes ``batch * pomo``.

        All static and dynamic tensors are (re)created on the instance and
        ``_move_to_next_machine`` is called once at the end.
        """
        # ---- problem source -------------------------------------------------
        if pomo == -1:
            # Instances supplied by the caller; sizes are read from 'p',
            # which is indexed as [instance][machine][job].
            self.problems_INT_dict = batch
            durations = batch['p']
            self.batch_size = len(durations)
            self.machine_num = len(durations[0])
            self.job_num = len(durations[0][0])
        else:
            self.job_num = self.env_params['job_num']
            self.machine_num = self.env_params['machine_num']

            if self.env_params['mode'] == "rand":
                generator = generate_data_batch_rand
            else:
                generator = generate_data_batch
            self.problems_INT_dict = generator(
                batch, pomo, self.machine_num, self.job_num, self.process_time_params
            )
            self.batch_size = batch * pomo

        problems = self.problems_INT_dict
        B, M, J = self.batch_size, self.machine_num, self.job_num
        self.BATCH_IDX = torch.arange(B)

        # ---- static features ------------------------------------------------
        self.job_durations = torch.tensor(problems['p']).float()
        self.job_setup = torch.tensor(problems['setup_times']).float()
        self.initial_setup = torch.tensor(problems['initial_setup']).float()

        # The initial setup is appended as one extra "previous job" row, so
        # index -1 along dim 2 selects it.
        initial_row = self.initial_setup.reshape(B, M, 1, J)
        self.job_setup_with_init = torch.cat([self.job_setup, initial_row], 2)
        self.job_durations_setup_with_init = (
            self.job_setup_with_init + self.job_durations.reshape(B, M, 1, J)
        )

        self.job_weight = torch.tensor(problems['w']).float()
        self.job_release = torch.tensor(problems['release_times']).float()

        # Arrival times are all zero, i.e. every job counts as arrived from
        # the start; release times are kept separately in ``job_release``.
        # NOTE(review): the original loaded 'release_times' here and then
        # overwrote it with zeros — confirm the zeroing is intentional.
        self.job_arrived_time = torch.zeros(B, J)

        self.job_due = torch.tensor(problems['due_dates']).float()

        # 'machine_eligibility' is indexed as [instance][job] -> iterable of
        # machine indices; convert it to a dense 0/1 tensor of shape (B, M, J).
        self.machine_eligibility = problems['machine_eligibility']
        eligibility = torch.zeros(B, M, J)
        for inst, per_job in enumerate(self.machine_eligibility):
            for job, machines in enumerate(per_job):
                for machine in machines:
                    eligibility[inst, machine, job] = 1
        self.machine_eligibility_reshaped = eligibility

        # ---- dynamic features -----------------------------------------------
        self.job_arrived = torch.zeros(B, J)
        self.job_scheduled = torch.zeros(B, J)
        self.job_not_scheduled = 1 - self.job_scheduled
        self.machine_remaining_time = torch.zeros(B, M)
        self.simulation_time = torch.zeros(B)
        # -1 means "no job processed yet" and selects the initial-setup row.
        self.machine_last_machine = torch.full((B, M), -1, dtype=torch.long)

        # ---- dynamic batch control ------------------------------------------
        self.decision_available = torch.zeros(B)
        self.finished = torch.zeros(B)

        # ---- schedule bookkeeping -------------------------------------------
        self.start_time = torch.zeros(B, J)
        self.end_time = torch.zeros(B, J)
        self.job_machine_location = torch.zeros(B, J)

        self._move_to_next_machine()
        
    def _step(self, action):
        """Schedule one job on one machine in every instance of the batch.

        Args:
            action: Indexable of length ``batch_size``.  ``action[b]`` is an
                index into the flattened (machine, available job) grid of
                instance ``b``: with ``n = len(self.job_indices[b])`` the
                machine is ``action[b] // n`` and the job is
                ``self.job_indices[b][action[b] % n]``.

        Returns:
            A tuple ``(state, reward, done)``.

            * While jobs remain, ``reward`` is ``None`` and ``state`` is the
              fresh result of ``_get_state()`` (also stored in ``self.state``).
            * Once every job of every instance is scheduled, ``reward`` is the
              negated result of ``_get_obj()`` and ``state`` is the last state
              built by an earlier step, or ``None`` when no state was ever
              built (previously this case raised ``AttributeError``).
        """
        # Decode the flat action into a (machine, job) pair per instance.
        machine_idx = torch.zeros(self.batch_size, dtype=torch.long)
        job_idx = torch.zeros(self.batch_size, dtype=torch.long)
        for b in range(self.batch_size):
            candidates = self.job_indices[b]
            n_candidates = len(candidates)
            machine_idx[b] = action[b] // n_candidates
            job_idx[b] = candidates[action[b] % n_candidates]

        rows = self.BATCH_IDX

        # Schedule bookkeeping: a job cannot start before its release time.
        start = torch.max(self.job_release[rows, job_idx], self.simulation_time)
        self.start_time[rows, job_idx] = start
        # A previous job of -1 selects the appended initial-setup row.
        previous_job = self.machine_last_machine[rows, machine_idx]
        job_length = self.job_durations_setup_with_init[rows, machine_idx, previous_job, job_idx]
        self.end_time[rows, job_idx] = start + job_length
        self.job_machine_location[rows, job_idx] = machine_idx.float()

        # Dynamic feature update.
        self.machine_last_machine[rows, machine_idx] = job_idx
        self.machine_remaining_time[rows, machine_idx] = (
            self.end_time[rows, job_idx] - self.simulation_time
        )
        self.job_scheduled[rows, job_idx] = 1

        done = self.job_scheduled.all()
        if done:
            # Note the MINUS sign: the objective is minimised, reward maximised.
            reward = -self._get_obj()
            # ``self.state`` does not exist yet if the episode ends on the
            # very first step, so fall back to None instead of raising.
            state = getattr(self, 'state', None)
        else:
            reward = None
            self._move_to_next_machine()
            self.state = self._get_state()
            state = self.state
        return state, reward, done
    
    def _move_to_next_machine(self, ):
        """Advance simulated time until every instance can take a decision.

        First the clock of each instance jumps forward by the smallest
        remaining machine time.  Then, repeatedly:

        * ``self.decision_available`` (batch, machine, job) is recomputed as
          the product of eligibility, machine-idle, job-unscheduled and
          job-arrived indicators; ``self.available_jobs`` and
          ``self.job_indices`` (arrived and unscheduled jobs per instance) are
          refreshed as well;
        * the loop stops once every instance either has at least one
          available decision or has all of its jobs scheduled;
        * otherwise each blocked instance is advanced to its next event (the
          earliest pending arrival or the earliest machine completion).

        Raises:
            RuntimeError: If a blocked instance has neither a pending arrival
                nor a busy machine.  Time can then never unblock it; this
                situation previously looped forever.
        """
        B = self.batch_size

        # Jump to the moment the first machine of each instance becomes free.
        time_step = torch.min(self.machine_remaining_time, 1).values.reshape(B)
        self.simulation_time += time_step
        self.machine_remaining_time -= time_step.reshape(B, 1)

        no_event = float('inf')
        while True:
            # Machines with zero remaining time are idle.
            machine_idle = (self.machine_remaining_time == 0).float()
            machine_available = machine_idle.unsqueeze(2).expand(-1, -1, self.job_num)

            # Jobs that still have to be scheduled.
            job_not_scheduled_1 = (self.job_scheduled == 0).float()
            job_not_scheduled = job_not_scheduled_1.unsqueeze(1).expand(-1, self.machine_num, -1)

            # Jobs that have already arrived at the current time.
            job_arrived_1 = (self.job_arrived_time <= self.simulation_time.reshape(B, 1)).float()
            job_arrived = job_arrived_1.unsqueeze(1).expand(-1, self.machine_num, -1)

            # Candidate jobs per instance and the full decision mask.
            self.available_jobs = job_arrived_1 * job_not_scheduled_1
            self.job_indices = [
                torch.nonzero(row == 1, as_tuple=True)[0] for row in self.available_jobs
            ]
            self.decision_available = (
                self.machine_eligibility_reshaped * machine_available * job_not_scheduled * job_arrived
            )
            decision_available_batch = self.decision_available.sum(dim=(1, 2))

            # Finished instances count as "ready" so they never block the batch.
            decision_available_batch += self.job_scheduled.all(dim=1).float()

            if bool((decision_available_batch != 0).all()):
                break

            # Same value as job_arrived_1: the clock has not moved since.
            self.job_arrived = job_arrived_1

            time_step = torch.zeros(B)
            for i in range(B):
                if decision_available_batch[i] > 0:
                    continue
                # Earliest arrival among jobs that have not arrived yet.
                pending_arrivals = self.job_arrived_time[i][self.job_arrived[i] != 1]
                next_arrival = pending_arrivals.min().item() if pending_arrivals.numel() else no_event
                # Earliest completion among machines that are still busy.
                busy_remaining = self.machine_remaining_time[i][self.machine_remaining_time[i] != 0]
                next_completion = busy_remaining.min().item() if busy_remaining.numel() else no_event

                wait = min(next_arrival, next_completion)
                if wait == no_event:
                    raise RuntimeError(
                        f"instance {i} has unscheduled jobs but no pending arrival and no "
                        "busy machine; no decision can ever become available"
                    )
                time_step[i] = wait

            self.simulation_time += time_step
            # Machines that finish before the new time would go negative; clip to 0.
            self.machine_remaining_time -= time_step.reshape(B, 1)
            self.machine_remaining_time = torch.clamp(self.machine_remaining_time, min=0)
    
    def _get_state(self):
        """Build one graph ``Data`` object per instance for the current time.

        For an instance with ``J`` available jobs (``self.job_indices[b]``)
        and ``M`` machines, the node matrix ``x`` has ``M * J + M`` rows and
        6 columns:

        * rows ``m * J + j``: a (machine, job) pair with columns
          ``[machine remaining time, 0, due date - now, weight,
          max(release - now, 0), eligibility]``;
        * the last ``M`` rows: machine nodes, currently all zeros.

        ``mask`` holds ``self.decision_available`` for the pair nodes followed
        by ``M`` zeros for the machine nodes.  Edge tensors come from
        ``_get_edge``.

        Returns:
            list: ``batch_size`` ``Data`` objects, in batch order.
        """
        B = self.batch_size
        M = self.machine_num
        now = self.simulation_time.reshape(B, 1)

        # Per-job features for the whole batch.
        self.job_feature = torch.zeros(B, self.job_num, 3)
        self.job_feature[:, :, 0] = self.job_due - now
        self.job_feature[:, :, 1] = self.job_weight
        self.job_feature[:, :, 2] = torch.clamp(self.job_release - now, min=0)

        # Per-machine features; only the first column is filled in.
        self.machine_feature = torch.zeros(B, M, 2)
        self.machine_feature[:, :, 0] = self.machine_remaining_time

        # Machine-node features are all-zero placeholders.
        self.machine_node_feature = torch.zeros(B, M, 5)

        graphs = []
        for b in range(B):
            job_indices = self.job_indices[b]
            J = len(job_indices)

            job_rows = self.job_feature[b, job_indices, :]   # (J, 3)
            machine_rows = self.machine_feature[b]           # (M, 2)

            # Cross every machine with every available job: (M, J, 5) -> (M*J, 5).
            pair_nodes = torch.cat(
                (
                    machine_rows.unsqueeze(1).expand(M, J, 2),
                    job_rows.unsqueeze(0).expand(M, J, 3),
                ),
                dim=-1,
            ).reshape(M * J, -1)

            # Append the eligibility flag as a sixth column.
            eligible_col = self.machine_eligibility_reshaped[b][:, job_indices].reshape(-1, 1)
            pair_nodes = torch.cat([pair_nodes, eligible_col], dim=-1)  # (M*J, 6)

            # Pad machine nodes with one zero column so widths match.
            machine_nodes = self.machine_node_feature[b, :]  # (M, 5)
            zero_col = torch.zeros((M, 1), dtype=machine_nodes.dtype)
            machine_nodes = torch.cat([machine_nodes, zero_col], dim=-1)  # (M, 6)

            x = torch.cat([pair_nodes, machine_nodes], 0)

            edge1, edge_1_feature, edge2, edge3, edge_3_feature = self._get_edge(b, job_indices)

            pair_mask = self.decision_available[b][:, job_indices].reshape(-1)
            mask = torch.cat([pair_mask, torch.zeros(M)])

            graphs.append(Data(
                x=x,
                edge_index_1=edge1,
                edge_1_feature=edge_1_feature,
                edge_index_2=edge2,
                edge_index_3=edge3,
                edge_3_feature=edge_3_feature,
                mask=mask,
            ))

        return graphs


    def _get_edge(self, batch, job_indices):
        """Build the three edge sets of one instance's graph.

        Node numbering (``J = len(job_indices)``, ``M = self.machine_num``):
        the (machine ``m``, job position ``i``) pair node is ``m * J + i`` and
        the node of machine ``m`` is ``M * J + m``.

        Args:
            batch: Index of the instance inside the batch.
            job_indices: 1-D tensor with the job ids that act as job positions.

        Returns:
            tuple: ``(edge1, edge_1_feature, edge2, edge3, edge_3_feature)``

            * ``edge1`` (2, M*J*J): for each machine, every ordered pair of
              job positions ``(i, j)``, self pairs included.
            * ``edge_1_feature`` (M*J*J,): setup-plus-duration of going from
              job ``i`` to job ``j`` on that machine, ``-1`` if either job is
              not eligible on the machine.
            * ``edge2`` (2, J*M*M): for each job position, every ordered pair
              of machines, self pairs included.
            * ``edge3`` (2, 2*M*J): machine node <-> pair node, both
              directions interleaved (machine->pair first).
            * ``edge_3_feature`` (2*M*J,): setup-plus-duration after the
              machine's last job (initial setup when it has none), ``-1`` if
              the job is not eligible; each value appears twice, once per
              direction.
        """
        job_indices = job_indices.long()
        J = len(job_indices)
        M = self.machine_num
        durations = self.job_durations_setup_with_init[batch]

        # Tensor indexing returns copies, so the -1 writes below do not touch
        # the stored tensor.
        setup_time = durations[:, job_indices, :][:, :, job_indices]           # (M, J, J)
        last_job = self.machine_last_machine[batch]
        machine_setup_time = durations[torch.arange(M), last_job, :][:, job_indices]  # (M, J)

        # Mark transitions that involve a job the machine cannot process.
        not_eligible = self.machine_eligibility_reshaped[batch][:, job_indices] == 0  # (M, J)
        setup_time[not_eligible.unsqueeze(2) | not_eligible.unsqueeze(1)] = -1
        machine_setup_time[not_eligible] = -1

        # pair_ids[m, i] == m * J + i
        pair_ids = torch.arange(M * J).reshape(M, J)

        # Edge set 1: same machine, job i -> job j, flattened in (m, i, j) order.
        edge1 = torch.stack(
            [
                pair_ids.unsqueeze(2).expand(M, J, J).reshape(-1),
                pair_ids.unsqueeze(1).expand(M, J, J).reshape(-1),
            ],
            dim=0,
        )

        # Edge set 2: same job, machine m1 -> machine m2, flattened in (j, m1, m2) order.
        ids_by_job = pair_ids.t()  # (J, M)
        edge2 = torch.stack(
            [
                ids_by_job.unsqueeze(2).expand(J, M, M).reshape(-1),
                ids_by_job.unsqueeze(1).expand(J, M, M).reshape(-1),
            ],
            dim=0,
        )

        # Edge set 3: machine node <-> each of its pair nodes, interleaved as
        # (machine -> pair), (pair -> machine), ...
        machine_nodes = (M * J + torch.arange(M)).repeat_interleave(J)
        job_nodes = pair_ids.reshape(-1)
        edge3 = torch.stack(
            [
                torch.stack([machine_nodes, job_nodes], dim=1).reshape(-1),
                torch.stack([job_nodes, machine_nodes], dim=1).reshape(-1),
            ],
            dim=0,
        )

        edge_1_feature = setup_time.reshape(-1).float()
        edge_3_feature = machine_setup_time.repeat_interleave(2)
        return edge1, edge_1_feature, edge2, edge3, edge_3_feature

    def _get_obj(self):
        """Return the total weighted tardiness of every instance.

        Tardiness of a job is ``max(end_time - due_date, 0)``; the values are
        multiplied by ``self.job_weight`` and summed over dimension 1.

        Returns:
            torch.Tensor: One value per row of ``self.end_time``.
        """
        lateness = self.end_time - self.job_due
        tardiness = lateness.clamp(min=0)  # finishing early earns no credit
        weighted = self.job_weight * tardiness
        return weighted.sum(dim=1)